Detecting shortcut learning for fair medical AI using shortcut testing

Machine learning (ML) holds great promise for improving healthcare, but it is critical to ensure that its use will not propagate or amplify health disparities. An important step is to characterize the (un)fairness of ML models—their tendency to perform differently across subgroups of the population—and to understand its underlying mechanisms. One potential driver of algorithmic unfairness, shortcut learning, arises when ML models base predictions on improper correlations in the training data. Diagnosing this phenomenon is difficult as sensitive attributes may be causally linked with disease. Using multitask learning, we propose a method to directly test for the presence of shortcut learning in clinical ML systems and demonstrate its application to clinical tasks in radiology and dermatology. Finally, our approach reveals instances when shortcutting is not responsible for unfairness, highlighting the need for a holistic approach to fairness mitigation in medical AI.


Supplementary Figure 2: Fairness-Performance results for other CXR labels.
Separation is plotted against AUC, with the age performance of each model represented by the color as in Figure 3. Similar patterns may be observed, whereby inducing a bias in the training dataset results in much more unfair model performance, which can be ameliorated by gradient reversal (center column, green dots), or exacerbated by increasing the age representation (purple dots). In contrast, balancing the training dataset results in baseline models (orange) which are considerably fairer, and gradient reversal results in degraded model performance without further fairness improvement. Source data are provided as a Source Data file.

Supplementary Figure 3: ShorT analysis of original NIH and subsampled datasets for Atelectasis and Abnormal (the complement of the No Finding label).
Biased datasets (middle column) result in significant dependence of fairness on age representation. In contrast, for balanced datasets (right column), there is no such dependence. In the original dataset, there is no dependence of fairness on age representation for Atelectasis; however, there is a significant positive correlation between fairness and age representation for the No Finding label. This implies that models which represent age more accurately (left) tend to be fairer (closer to 0 on the y axis). This may be explained by an underuse of age information for this particular dataset and task. For all plots, an AUC threshold was set at 0.7, with replicates with an AUC value less than this being excluded from the correlation analysis. We chose 0.7 as the threshold as the performance of baseline models was lower for Atelectasis and Abnormal labels (Supplementary Figure 1). One such replicate is not displayed on the Abnormal, Original Dataset plot, as it had a separation of > 0.1, and lies beyond the limits of the y axis. All tests are two-sided Spearman correlations. Source data are provided as a Source Data file.

Figure 5 are displayed on a single graph, with replicates pooled according to predefined degrees of age encoding. In each degree of age encoding, we define its average (line) and estimate variability across models with bootstrapping (95% confidence intervals, error bars). For a given level of age encoding, models trained on the Balanced (solid line), Original (dashed line), or Biased (dotted line) datasets display vastly different fairness characteristics. (b) Age encoding vs Performance. At the same level of age encoding, performance is very similar for the Balanced and Original datasets, although the performance of the Balanced dataset drops off at higher age prediction errors. The Biased dataset results in a spuriously higher AUC due to cleaner class separation (see Supplementary Figure 6). Source data are provided as a Source Data file (see Figures 2c, 3c and 3f).
Figure 6: Cross-dataset performance and fairness for the effusion prediction task. AUC and Separation are shown for baseline models (without an age prediction head) trained on biased, original, and balanced datasets (x axis), tested on all three datasets. In-distribution results are located on the top-left to bottom-right diagonal. Note that the best performance is obtained in models trained on biased datasets, tested in-distribution; however, performance is degraded for out-of-distribution test sets, due to shortcut learning; this increase in performance is therefore spurious. Models trained on balanced datasets obtain similar performance results to those trained on the original dataset. However, separation is considerably improved in models trained on balanced data.

Figure 7: Effect of age head gradient scaling on age representation for the dermatology example in Figure 6. ShorT models covered a range of age prediction errors, although there appeared to be a wider plateau in the middle of the range of age head scaling values, over which age prediction error was quite similar to baseline. This plateau, as well as the wide range between the "Clinical Label" and "No Information" upper error bounds, likely occurs due to the richer soft labels used in this example, as well as due to the stronger dependence between age and condition probability for many (but not all) dermatological problems. Each dot represents a model trained (25 values of gradient scaling times 5 replicates), with error bars denoting 95% confidence intervals from bootstrapping examples (n = 1,925 independent patients) within a model. Source data are provided as a Source Data file.

CXR Datasets
We use two CXR datasets, NIH CXR and CheXpert. The NIH CXR Dataset is provided by the NIH Clinical Center and is available at https://nihcc.app.box.com/v/ChestXray-NIHCC. For experiments, images were first downsized to 448x448 pixels. We select the "Effusion", "Atelectasis" and "No findings" (which we report as "Abnormal" for semantic consistency) labels provided with data as our binary outcomes, focussing on Effusion.
CheXpert is available at https://stanfordmlgroup.github.io/competitions/chexpert/. Demographic labels are available at https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf. Following 22, we focus on a binary distinction between Black and White patients, rather than treating race as a multi-class prediction task. Images were downsized to 448x448 pixels, and we select cardiomegaly as a binary outcome for reporting.
The demographic information for the NIH dataset is as follows, broken down by findings:

Model Architectures
For the medical imaging models, we employ convolutional neural networks as image embedding models, followed by multi-layer perceptrons (MLPs) as classification models for both clinical classification and age/race prediction. We use modified ResNet 101x3 architectures 1 pre-trained on the public ImageNet 21k dataset. Model architectures for image embedding and checkpoints are available on TensorFlow Hub.
In the dermatology task, each clinical case includes 1 to 6 images. We average the embeddings across a clinical case before passing them to the MLP. All clinical classification MLPs have 2 layers, with 512 hidden units and ReLU activation, while all sensitive attribute MLPs have 3 layers, with 512 and 256 hidden units and ReLU activation.
In array pseudocode, the architecture follows: Where the reverse_gradient is an operation that allows for gradient scaling and lambda is a hyper-parameter (positive or negative) that controls for the strength of the scaling. See the code available at https://github.com/google-research/google-research/tree/master/shortcut_testing for an example implementation of these different operations.
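As a rough illustration of the two-headed layout described above, the following minimal sketch uses a single shared feature layer with two linear heads and computes the gradient that reaches the shared features by hand. The function names, the tiny layer sizes, and the sign convention (a negative `lam` reverses the attribute gradient) are assumptions made for illustration; the MLP heads described in this section are deeper, and the linked repository remains the reference implementation.

```python
import math

def forward(x, w_feat, w_cond, w_attr):
    """Tiny stand-in network: one shared feature layer, two linear heads.

    w_feat is a list of weight columns, one per feature unit (hypothetical layout).
    """
    feats = [math.tanh(sum(xi * wi for xi, wi in zip(x, col))) for col in w_feat]
    cond_logit = sum(f * w for f, w in zip(feats, w_cond))  # clinical head
    # The gradient-scaling op is the identity in the forward pass.
    attr_pred = sum(f * w for f, w in zip(feats, w_attr))   # attribute head
    return feats, cond_logit, attr_pred

def feature_gradient(d_cond, d_attr, w_cond, w_attr, lam):
    """Gradient of the combined loss with respect to the shared features.

    d_cond and d_attr are the derivatives of each head's loss with respect
    to that head's output. The attribute-head gradient is multiplied by
    `lam` on its way back (assumed convention: lam < 0 reverses it,
    lam > 0 reinforces it, lam = 0 blocks it).
    """
    return [d_cond * wc + lam * d_attr * wa for wc, wa in zip(w_cond, w_attr)]
```

With `lam = 0` the feature extractor is updated by the clinical loss alone, which is one way to see why sweeping `lam` over negative and positive values changes how strongly the attribute is encoded.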

Hyperparameter Tuning and model selection
All models were tuned for batch size, learning rate, weight decay, and dropout in the penultimate layer before training. The same parameters were applied to models trained on each label in the CXR task.

Multitask Prediction
To adapt a single task prediction model to multitask prediction, we added a demographic (age prediction) head at the final layer of the base model. There are no hidden layers between the feature extractor and the condition output layer. However, the demographic head itself uses two fully connected hidden layers between the gradient reversal layer and the final age output layer, to provide the network with capacity during adversarial training.
Next, in order to approximately balance the losses between the age (mean square error) and condition (cross-entropy) heads, we down-weighted the regression loss by a factor of 100. We then tested further adjustments to this loss weighting using a grid search (in conjunction with a coarse gradient scaling parameter sweep). In our case, we found that simple balancing of losses was sufficient.
Once the loss weighting was established, this was fixed for all further experiments. We then swept over 25 values for scaling of the gradient updates from the demographic head, ranging from -0.1 to +0.1 (spaced exponentially). For each value of gradient scaling, 5 replicates were trained, resulting in 125 models per experiment.
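One possible way to lay out such a sweep is sketched below. The smallest magnitude (`min_abs`), the inclusion of a zero value, and the symmetric split into 12 negative and 12 positive values are assumptions chosen only so that the grid has 25 exponentially spaced entries between -0.1 and +0.1; the exact grid used in the experiments is not specified here.

```python
def gradient_scale_grid(max_abs=0.1, min_abs=1e-4, n_per_side=12):
    """Symmetric, exponentially spaced grid of gradient-scaling values.

    min_abs and the presence of 0.0 are hypothetical choices.
    """
    ratio = (max_abs / min_abs) ** (1.0 / (n_per_side - 1))
    positive = [min_abs * ratio ** i for i in range(n_per_side)]
    negative = [-v for v in reversed(positive)]
    return negative + [0.0] + positive

grid = gradient_scale_grid()
# Five replicates per gradient-scaling value, as described in the text.
runs = [(lam, replicate) for lam in grid for replicate in range(5)]
```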
For attribute transfer experiments, the feature extractor was frozen and then a linear demographic prediction head was applied and the model retrained to predict age. Hidden layers were not required in this simpler (single task) prediction setup; we found that the addition of one or two hidden layers made no material difference to our results.

Subsampling of training data
In order to produce datasets with a shift in the mean age between the ground truth classes, we use a logistic probability function, which defines the probability of an example being retained as a function of the age of the patient:
Where k is the slope of the function; a_0 is the midpoint of the probability function (the age at which the probability of being retained is 0.5); and m is a scale factor that increases the probability of retaining examples. This defines a probability of retaining a positive example; for negative examples (patients without the condition), we use 1 - p_retain.
The following parameters were used to generate subsampled training sets. Since the process is stochastic, these were obtained by trial and error.

Label          k (Biased)    k (Balanced)    a_0    m
Effusion       0.14          -0.07           50     4
Atelectasis    0.12          -0.08           50     4
Abnormal       0.14          -0.065          50     4

The training sets generated using these parameters are described below. These perturbed datasets do not precisely match the desired shift in ages due to stochastic errors.
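The age-dependent retention rule described in this section can be sketched as follows. The exact functional form is an assumption: a standard logistic in age with slope `k` and midpoint `a0`, multiplied by `m` and clipped to [0, 1] so that it stays a valid probability when `m` is greater than 1. How `m` enters the original function is not recoverable from the text, so `m` defaults to 1 here.

```python
import math
import random

def p_retain(age, k, a0, m=1.0):
    """Assumed retention probability for a positive example."""
    p = m / (1.0 + math.exp(-k * (age - a0)))
    return min(1.0, max(0.0, p))  # clipping is an assumption, needed when m > 1

def subsample(examples, k, a0, m=1.0, seed=0):
    """Keep each (age, label) pair at random.

    Positives are kept with probability p_retain and negatives with
    1 - p_retain, which shifts the mean age of the two classes apart.
    """
    rng = random.Random(seed)
    kept = []
    for age, label in examples:
        p = p_retain(age, k, a0, m)
        keep_prob = p if label == 1 else 1.0 - p
        if rng.random() < keep_prob:
            kept.append((age, label))
    return kept
```

With a positive `k`, older positive patients and younger negative patients are preferentially retained; a negative `k` pushes the two class means in the opposite direction.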

Significance testing when comparing ShorT across datasets
Shortcut testing (ShorT) relies on calculating the correlation between the degree of age encoding and fairness metrics. To test that the ShorT statistics differ across datasets, we perform permutation tests of Spearman's rho across different versions of the training dataset. We calculate the true difference in correlation statistics, and compare it to an empirical null distribution of differences. The null distribution is simulated using bootstrapping. We combine the data points from the two groups, shuffle them, and randomly divide them into two groups. To calculate p-values, we compare the true difference to this null distribution.
For CXR, we find that differences are highly significant when comparing the original and biased datasets, and the original and balanced datasets (p = 1e-8; p = 1e-4, respectively), indicating that shortcutting happens significantly more with biased datasets, and significantly less with a balanced dataset.
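The pooling-and-shuffling procedure described above can be sketched in plain Python as below. The function names, the two-sided comparison, and the add-one p-value estimate are assumptions made for illustration; each group is taken to be a list of (age encoding, fairness) pairs, one per trained model.

```python
import random

def ranks(values):
    """1-based ranks, with tied values sharing their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    out = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for t in range(i, j + 1):
            out[order[t]] = (i + j) / 2.0 + 1.0
        i = j + 1
    return out

def pearson(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / (vx * vy) ** 0.5

def spearman(x, y):
    """Spearman's rho as the Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))

def permutation_test_rho_diff(group_a, group_b, n_perm=2000, seed=0):
    """Permutation test for a difference in Spearman's rho between two groups."""
    rng = random.Random(seed)

    def rho(group):
        xs, ys = zip(*group)
        return spearman(list(xs), list(ys))

    observed = rho(group_a) - rho(group_b)
    pooled = list(group_a) + list(group_b)
    n_a = len(group_a)
    extreme = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)  # reassign models to the two groups at random
        diff = rho(pooled[:n_a]) - rho(pooled[n_a:])
        if abs(diff) >= abs(observed):
            extreme += 1
    # The add-one correction keeps the estimated p-value above zero.
    return observed, (extreme + 1) / (n_perm + 1)
```

The smallest p-value this estimator can return is 1 / (n_perm + 1), so resolving very small p-values requires a correspondingly large number of permutations.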

Race in cardiomegaly models
To test ShorT in the context of a spurious attribute, we apply it to race in chest x-ray analysis. Following previous work 2, we analyze self-reported race in the CheXpert dataset as a binary task of predicting White and Black self-reported race from chest x-rays. We treat the Uncertain label for Cardiomegaly as negative.
The public validation set available for CheXpert only contains 9 individuals with self-reported Black race. Due to this small sample size, we instead randomly re-split the training data into new training (85%), validation (5%), and testing (10%).
This re-split has the following properties:
We focus on the cardiomegaly prediction, as the cardiomegaly label is imbalanced for race (prevalence for White patients was 11.5%, prevalence for Black patients was 19.8%). Similar to our age models, we train models to directly predict race to estimate the upper bound of performance on race (as per the AUROC on this binary prediction task). We then train models to predict both cardiomegaly and race, while sweeping over the gradient scale for the race prediction head. We set the weight of both heads as equal, as the scale of the loss is the same order of magnitude, and vary the gradient scale between -0.1 and +0.1 to match other experiments in the paper. Using an implementation inspired by Alabdulmohsin et al. 3, we estimate fairness via equalized odds.

Dermatology Dataset and experiments
For dermatology experiments, models are trained to predict 26 skin conditions with an additional "other" category to capture the long tail of conditions, as a multiclass prediction task, as described in 4. Our approach differs slightly from previously published results, as we use a more modern architecture (ResNet 101x3 rather than Inception v4), and a slightly smaller training dataset. The commercial dataset used consists of teledermatology images with associated diagnoses obtained by labeling by multiple dermatologists. Unfortunately, this dataset is not available for public use.
We assess model performance for a single class by using binarised metrics. For AUC, we use the prediction score of the chosen class. For separation, we define positive predictions to be examples where the top ranking prediction score is for the chosen class. Using top-3 selection (i.e. a positive prediction is any example where the score for the chosen class is in the top-3 scores) did not change our results.
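The binarisation rule for separation can be written in a few lines. The helper name `is_positive_prediction` and the toy scores are hypothetical; `k=1` corresponds to the top-ranking rule and `k=3` to the top-3 variant mentioned above.

```python
def is_positive_prediction(scores, class_index, k=1):
    """True if the chosen class is among the k highest-scoring classes."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return class_index in ranked[:k]

# Toy multiclass scores for one case (illustrative values only).
scores = [0.10, 0.55, 0.05, 0.30]
top1 = is_positive_prediction(scores, class_index=3, k=1)
top3 = is_positive_prediction(scores, class_index=3, k=3)
```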

Dermatology Dataset - Demographics
Since the dermatological dataset is not publicly available, we report here the basic demographics of the training dataset used.

We generated simulated data to assess the efficacy of ShorT. The data consisted of MNIST images 5 with the labels representing whether the number hand-written in the image was smaller than 5, or 5 and above. To these images, we added a small colored square at a random location. The color of the square (red or green) could be correlated with the label, and here plays the role of the sensitive attribute A. Noise was added to the image and the square as the tasks were straightforward. We hence obtain a data generating process that corresponds to Figure 1(b). As we control the data generating process, we are also able to generate counterfactual samples, i.e. images for which the color of the square has been switched.
We implemented ShorT with TensorFlow 6 v2 and Keras, using as feature extractor a small MLP of 3 dense layers with 10 units each. For gradient reversal, we added one more dense layer of size 2 before the attribute's output layer while the label was directly predicted from the feature extractor. Attribute encoding was assessed as the ROC AUC after training an output layer from a frozen feature extractor. Fairness was computed via equalized odds. Baseline model accuracy was between 0.8 and 0.86.
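For a binary label and a binary attribute, one common way to summarise equalized odds is the largest gap in true-positive or false-positive rate between the two attribute groups. The sketch below takes that summary as an assumption; the text does not state which equalized-odds statistic was computed, and the function names are illustrative.

```python
def positive_rate(preds, labels, target_label):
    """Fraction predicted positive among examples whose label is target_label."""
    selected = [p for p, y in zip(preds, labels) if y == target_label]
    return sum(selected) / len(selected)

def equalized_odds_gap(preds, labels, attrs):
    """Largest between-group gap in FPR (label 0) or TPR (label 1).

    preds, labels and attrs are equal-length lists of 0/1 values.
    A value of 0 means both groups share the same TPR and FPR.
    """
    gaps = []
    for target_label in (0, 1):
        rates = []
        for group in (0, 1):
            p = [pi for pi, ai in zip(preds, attrs) if ai == group]
            y = [yi for yi, ai in zip(labels, attrs) if ai == group]
            rates.append(positive_rate(p, y, target_label))
        gaps.append(abs(rates[0] - rates[1]))
    return max(gaps)
```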
Further hyper-parameter selection was needed to balance the losses of the target (weight = 1) and of the attribute (search between 0.5 and 1.0). The final value was selected as 0.75. We varied the correlation between Y and A such that a label of Y=0 was associated with a red square between 50 and 95% of the time (20 steps), while the label Y=1 was associated with a red square between 50 and 15% of the time (20 steps).
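A minimal sketch of how the square color could be drawn with a label-dependent probability, and how the two 20-step sweeps could be laid out, is given below. The helper names and the linear spacing of the steps are assumptions; the image generation itself is omitted.

```python
import random

def linspace(start, stop, n):
    """n evenly spaced values from start to stop, inclusive."""
    return [start + (stop - start) * i / (n - 1) for i in range(n)]

# Probability of a red square for each label, swept in 20 steps (assumed linear).
p_red_given_y0 = linspace(0.50, 0.95, 20)
p_red_given_y1 = linspace(0.50, 0.15, 20)

def sample_color(label, p_red_y0, p_red_y1, rng):
    """Draw the attribute A (square color) conditionally on the label Y."""
    p_red = p_red_y0 if label == 0 else p_red_y1
    return "red" if rng.random() < p_red else "green"

def counterfactual_color(color):
    """Switch the square color to build the counterfactual sample."""
    return "green" if color == "red" else "red"
```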
We observed that ShorT produces significant results for high correlations between Y and A (Supplementary Figure 8a). This corresponds to our observations with counterfactuals that, given the simplicity of the task, the model does not "need" to rely on the attribute for predictions if the correlation between A and Y is not high. Focussing on the low correlation setting, we uniformly sampled the correlations between A and Y in the 0.4-0.6 range (n=50) and assessed the number of significant results for ShorT (at a threshold of p<0.05, Bonferroni corrected). We note that only 3 instances lead to significant p-values for ShorT (i.e. 3/50=0.06 ≈ 0.05, Supplementary Figure 8b). Finally, we focussed on the high correlation setting and sampled uniformly in the 0.9-0.98 range for label Y=1 and in the 0.15-0.23 range for Y=0. Note that the asymmetry is needed to obtain unfairness based on equalized odds. In this case, we observe that ShorT correctly identifies shortcutting in all instances (Supplementary Figure S8, 50/50), even after Bonferroni correction for multiple comparisons (50/50).

Figure S8: ShorT on simulated data. (a) Increasing the correlation between label and color in a consistent but asymmetric fashion leads to significant shortcutting for high values of the correlation. Each dot represents the p-value of ShorT computed based on a different combination of correlations. We focus on two areas (shaded on the plot): a low correlation setting (detailed in (b)) to assess type I error and a high correlation setting (detailed in (c)) to assess type II error. (b) For low values of the correlation and a small asymmetry (x-axis), we obtain a uniform distribution of ShorT p-values. (c) p-values are consistently lower than p<0.05 when the asymmetry is high and the correlation between A and Y is large. All tests are two-sided Spearman correlations, with p-values corrected for multiple comparisons using Bonferroni correction.

Supplementary Discussion
In our analysis, we have chosen to preserve age as a continuous variable, using logistic regression analysis to characterize the fairness properties of the model. This avoids the need for arbitrary quantization of the data. However, it does assume that discrepancies, where observed, will be monotonic, with weaker performance for either older or younger patients. In cases where we may expect bimodal or more complex distributions of fairness properties it might be more judicious to examine the model outputs rather than rely on particular formulations of fairness metrics. Distribution-free approaches [7][8][9] may be considered if no particular form of association can be expected, although these will in general be more limited in power and interpretability. Secondly, the use of a LR model requires a binarised outcome per example, and would be unsuitable for metrics such as prediction scores (continuous) or AUC (requires a set of observations). Alternative methods 10,11 may overcome some of these limitations, at the expense of interpretability. However, our framework does not require the use of a continuous attribute, and may be applied to binary or discrete variables, by substituting the model-based fairness metrics for conventional definitions.